# Copyright 2021 solo-learn development team.

# Permission is hereby granted, free of charge, to any person obtaining a copy of
# this software and associated documentation files (the "Software"), to deal in
# the Software without restriction, including without limitation the rights to use,
# copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the
# Software, and to permit persons to whom the Software is furnished to do so,
# subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all copies
# or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
# INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR
# PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE
# FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
# OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.

import os
from pprint import pprint

from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.callbacks import LearningRateMonitor
from pytorch_lightning.loggers import WandbLogger
from orion.client import cli as orion_cli

from solo.args.setup import parse_args_pretrain
from solo.methods import METHODS
from solo.utils.auto_resumer import AutoResumer


try:
    from solo.methods.dali import PretrainABC
except ImportError as e:
    print(e)
    _dali_avaliable = False
else:
    _dali_avaliable = True

try:
    from solo.utils.auto_umap import AutoUMAP
except ImportError:
    _umap_available = False
else:
    _umap_available = True

import types

from solo.utils.checkpointer import Checkpointer
from solo.utils.classification_dataloader import prepare_data as prepare_data_classification
from solo.utils.pretrain_dataloader import (
    prepare_dataloader,
    prepare_datasets,
    prepare_n_crop_transform,
    prepare_transform,
    ImagenetDataModule,
)


def main():
    args = parse_args_pretrain()
    seed_everything(args.seed)

    assert args.method in METHODS, f"Choose from {METHODS.keys()}"

    if args.num_large_crops != 2:
        assert args.method == "wmse"

    MethodClass = METHODS[args.method]
    if args.dali:
        assert (
            _dali_avaliable
        ), "Dali is not currently avaiable, please install it first with [dali]."
        MethodClass = types.new_class(f"Dali{MethodClass.__name__}", (PretrainABC, MethodClass))

    model = MethodClass(**args.__dict__)

    # pretrain dataloader
    if not args.dali:
        # asymmetric augmentations
        if args.unique_augs > 1:
            transform = [
                prepare_transform(args.dataset, **kwargs) for kwargs in args.transform_kwargs
            ]
        else:
            transform = [prepare_transform(args.dataset, **args.transform_kwargs)]

        transform = prepare_n_crop_transform(transform, num_crops_per_aug=args.num_crops_per_aug)
        if args.debug_augmentations:
            print("Transforms:")
            pprint(transform)

        if args.dataset != 'imagenet':
            train_dataset = prepare_datasets(
                args.dataset,
                transform,
                data_dir=args.data_dir,
                train_dir=args.train_dir,
                no_labels=args.no_labels,
            )
            train_loader = prepare_dataloader(
                train_dataset, batch_size=args.batch_size, num_workers=args.num_workers
            )

            # normal dataloader for when it is available
            if args.dataset == "custom" and (args.no_labels or args.val_dir is None):
                val_loader = None
            elif args.dataset in ["imagenet100", "imagenet"] and args.val_dir is None:
                val_loader = None
            else:
                _, val_loader = prepare_data_classification(
                    args.dataset,
                    data_dir=args.data_dir,
                    train_dir=args.train_dir,
                    val_dir=args.val_dir,
                    batch_size=args.batch_size,
                    num_workers=args.num_workers,
                )
        else:
            dm = ImagenetDataModule(args.data_dir, args.train_dir, args.val_dir, args.batch_size, args.num_workers, transform)

    callbacks = []

    # wandb logging
    os.makedirs(args.checkpoint_dir, exist_ok=True)
    os.makedirs(os.path.join(args.checkpoint_dir, 'wandb'), exist_ok=True)
    if args.wandb:
        wandb_logger = WandbLogger(
            name=args.name,
            #  id=args.name,
            save_dir=str(args.checkpoint_dir),
            group=args.group,
            project=args.project,
            entity=args.entity,
            offline=args.offline,
        )
        wandb_logger.watch(model, log="gradients", log_freq=100)
        wandb_logger.log_hyperparams(args)

        # lr logging
        lr_monitor = LearningRateMonitor(logging_interval="epoch")
        callbacks.append(lr_monitor)

    if args.save_checkpoint:
        # save checkpoint on last epoch only
        ckpt = Checkpointer(
            args,
            logdir=args.checkpoint_dir,
            frequency=args.checkpoint_frequency,
        )
        callbacks.append(ckpt)

    if args.auto_umap:
        assert (
            _umap_available
        ), "UMAP is not currently avaiable, please install it first with [umap]."
        auto_umap = AutoUMAP(
            args,
            logdir=os.path.join(args.auto_umap_dir, args.method),
            frequency=args.auto_umap_frequency,
        )
        callbacks.append(auto_umap)

    # 1.7 will deprecate resume_from_checkpoint, but for the moment
    # the argument is the same, but we need to pass it as ckpt_path to trainer.fit
    ckpt_path = None
    if args.auto_resume and args.resume_from_checkpoint is None:
        auto_resumer = AutoResumer(
            checkpoint_dir=args.checkpoint_dir,
            max_hours=args.auto_resumer_max_hours,
        )
        resume_from_checkpoint = auto_resumer.find_checkpoint(args)
        if resume_from_checkpoint is not None:
            print(
                "Resuming from previous checkpoint that matches specifications:",
                f"'{resume_from_checkpoint}'",
            )
            ckpt_path = resume_from_checkpoint
    elif args.resume_from_checkpoint is not None:
        ckpt_path = args.resume_from_checkpoint
        del args.resume_from_checkpoint

    trainer = Trainer.from_argparse_args(
        args,
        logger=wandb_logger if args.wandb else None,
        callbacks=callbacks,
        enable_checkpointing=False,
    )

    #a = iter(train_loader)
    #data = next(a)

    #model.eval()

    #y = data[1][0][:1]

    #from fvcore.nn import FlopCountAnalysis
    ##from fvcore.nn import ActivationCountAnalysis
    #from decimal import Decimal
    #flops = FlopCountAnalysis(model, y)
    ##flops = ActivationCountAnalysis(model, y)
    #print(f'{Decimal(flops.total()):.2E}')
    #exit(0)

    try:
        if args.dali:
            trainer.fit(model, val_dataloaders=val_loader, ckpt_path=ckpt_path)
        else:
            if args.dataset != 'imagenet':
                trainer.fit(model, train_loader, val_loader, ckpt_path=ckpt_path)
            else:
                trainer.fit(model, datamodule=dm, ckpt_path=ckpt_path)

    except RuntimeError:
        orion_cli.report_bad_trial()
        raise
    else:
        # Orion minimize the following objective
        obj = 100 - float(trainer.callback_metrics["val_acc1"])
        orion_cli.report_objective(obj)


if __name__ == "__main__":
    main()
